{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import cv2\n",
    "import torchvision\n",
    "import numpy as np\n",
    "from keras.preprocessing.image import load_img\n",
    "import torchvision.transforms as T\n",
    "\n",
    "\n",
    "# Load a pretrained Faster R-CNN (ResNet-50 + FPN) and switch it to eval mode.\n",
    "model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)\n",
    "model.eval()\n",
    "\n",
    "# Read the sample image resized to 800x800, then convert it to a tensor.\n",
    "img_path = '2007_000032.jpg'\n",
    "img = load_img(img_path, target_size=(800, 800))\n",
    "transform = T.Compose([T.ToTensor()])  # Define the PyTorch transform\n",
    "img = transform(img)  # Apply the transform to the image\n",
    "\n",
    "# Pass the single-image batch through model.transform; no targets are supplied.\n",
    "images = [img]\n",
    "targets=None\n",
    "original_image_sizes = [img.shape[-2:] for img in images]\n",
    "images, targets = model.transform(images, targets)\n",
    "\n",
    "# Extract features\n",
    "features = model.backbone(images.tensors)\n",
    "\n",
    "# Extract proposals\n",
    "proposals, proposal_losses = model.rpn(images, features, targets)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以其中一个 proposal `[257.4335, 268.6282, 648.8815, 444.5958]` 为例：将其坐标除以特征图的 stride（16）后得到 `[16.0896, 16.7893, 40.5551, 27.7872]`，四个数依次对应 RoI 左上角的 x、y 坐标和右下角的 x、y 坐标。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# One ROI row: a leading 0 followed by the scaled proposal corners (x1, y1, x2, y2).\n",
    "rois_in_feature = torch.tensor(\n",
    "    [[0, 16.0896, 16.7893, 40.5551, 27.7872]], dtype=torch.float32\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature-map size and pooled output resolution used by the manual computation.\n",
    "height, width = 50, 50\n",
    "pooled_height, pooled_width = 7, 7\n",
    "\n",
    "# ROI corners after scaling: (start_w, start_h) is top-left, (end_w, end_h) is bottom-right.\n",
    "roi_start_w, roi_start_h = 16.0896, 16.7893\n",
    "roi_end_w, roi_end_h = 40.5551, 27.7872"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ROI extent, forced to be at least 1 along each axis.\n",
    "roi_height = max(roi_end_h - roi_start_h, 1)\n",
    "roi_width = max(roi_end_w - roi_start_w, 1)\n",
    "\n",
    "# Size of one output bin: the ROI split evenly into the pooled grid.\n",
    "bin_size_w = roi_width / pooled_width\n",
    "bin_size_h = roi_height / pooled_height"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling grid inside each output bin (rows x columns).\n",
    "roi_bin_grid_h, roi_bin_grid_w = 2, 2\n",
    "# Total number of sample points per bin.\n",
    "count = roi_bin_grid_h * roi_bin_grid_w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the backbone output stored under key 2.\n",
    "# NOTE(review): presumably the stride-16 level that the 50x50 size above refers to - confirm.\n",
    "per_level_feature = features[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.1315, grad_fn=<DivBackward0>)\n",
      "tensor(0.3561, grad_fn=<DivBackward0>)\n",
      "tensor(0.1099, grad_fn=<DivBackward0>)\n",
      "tensor(0.0105, grad_fn=<DivBackward0>)\n",
      "tensor(0.1069, grad_fn=<DivBackward0>)\n",
      "tensor(0.4987, grad_fn=<DivBackward0>)\n",
      "tensor(-0.0353, grad_fn=<DivBackward0>)\n",
      "tensor(0.5771, grad_fn=<DivBackward0>)\n",
      "tensor(0.1737, grad_fn=<DivBackward0>)\n",
      "tensor(0.0971, grad_fn=<DivBackward0>)\n",
      "tensor(-0.0943, grad_fn=<DivBackward0>)\n",
      "tensor(-0.0527, grad_fn=<DivBackward0>)\n",
      "tensor(0.6988, grad_fn=<DivBackward0>)\n",
      "tensor(0.3200, grad_fn=<DivBackward0>)\n",
      "tensor(-0.1067, grad_fn=<DivBackward0>)\n",
      "tensor(-0.4572, grad_fn=<DivBackward0>)\n",
      "tensor(-0.3240, grad_fn=<DivBackward0>)\n",
      "tensor(-0.5574, grad_fn=<DivBackward0>)\n",
      "tensor(-0.3299, grad_fn=<DivBackward0>)\n",
      "tensor(0.5116, grad_fn=<DivBackward0>)\n",
      "tensor(0.3760, grad_fn=<DivBackward0>)\n",
      "tensor(-0.5320, grad_fn=<DivBackward0>)\n",
      "tensor(-0.6477, grad_fn=<DivBackward0>)\n",
      "tensor(-0.8813, grad_fn=<DivBackward0>)\n",
      "tensor(-0.8584, grad_fn=<DivBackward0>)\n",
      "tensor(-0.5922, grad_fn=<DivBackward0>)\n",
      "tensor(0.2852, grad_fn=<DivBackward0>)\n",
      "tensor(0.3046, grad_fn=<DivBackward0>)\n",
      "tensor(-0.3979, grad_fn=<DivBackward0>)\n",
      "tensor(-0.1485, grad_fn=<DivBackward0>)\n",
      "tensor(-0.5802, grad_fn=<DivBackward0>)\n",
      "tensor(-0.2516, grad_fn=<DivBackward0>)\n",
      "tensor(-0.3406, grad_fn=<DivBackward0>)\n",
      "tensor(0.3955, grad_fn=<DivBackward0>)\n",
      "tensor(0.3633, grad_fn=<DivBackward0>)\n",
      "tensor(0.1361, grad_fn=<DivBackward0>)\n",
      "tensor(-0.0142, grad_fn=<DivBackward0>)\n",
      "tensor(0.1989, grad_fn=<DivBackward0>)\n",
      "tensor(0.8769, grad_fn=<DivBackward0>)\n",
      "tensor(0.2482, grad_fn=<DivBackward0>)\n",
      "tensor(0.3156, grad_fn=<DivBackward0>)\n",
      "tensor(0.4014, grad_fn=<DivBackward0>)\n",
      "tensor(0.6164, grad_fn=<DivBackward0>)\n",
      "tensor(0.3802, grad_fn=<DivBackward0>)\n",
      "tensor(0.7231, grad_fn=<DivBackward0>)\n",
      "tensor(0.8693, grad_fn=<DivBackward0>)\n",
      "tensor(0.3949, grad_fn=<DivBackward0>)\n",
      "tensor(0.4931, grad_fn=<DivBackward0>)\n",
      "tensor(0.9423, grad_fn=<DivBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# Manual RoIAlign on channel 0 of the first image: every output bin is the mean of\n",
    "# roi_bin_grid_h * roi_bin_grid_w bilinearly interpolated sample points.\n",
    "# The grid size and sample count come from the configuration cells instead of\n",
    "# being hard-coded, so changing them there changes this loop consistently.\n",
    "assert count == roi_bin_grid_h * roi_bin_grid_w, 'count must equal the samples per bin'\n",
    "channel_map = per_level_feature[0][0]\n",
    "\n",
    "for h in range(pooled_height):\n",
    "    for w in range(pooled_width):\n",
    "        res = 0\n",
    "        # x-outer / y-inner iteration keeps the accumulation order of the samples fixed.\n",
    "        for ix in range(roi_bin_grid_w):\n",
    "            for iy in range(roi_bin_grid_h):\n",
    "                # Sample points are the centres of an even sub-grid inside the bin.\n",
    "                point_x = roi_start_w + w*bin_size_w + bin_size_w*(2*ix + 1)/(2*roi_bin_grid_w)\n",
    "                point_y = roi_start_h + h*bin_size_h + bin_size_h*(2*iy + 1)/(2*roi_bin_grid_h)\n",
    "\n",
    "                # int() truncates toward zero, which equals floor only for non-negative\n",
    "                # coordinates; the +1 neighbours must also lie inside the feature map.\n",
    "                x_low = int(point_x)\n",
    "                y_low = int(point_y)\n",
    "\n",
    "                top_left = channel_map[y_low][x_low]\n",
    "                top_right = channel_map[y_low][x_low + 1]\n",
    "                bottom_left = channel_map[y_low + 1][x_low]\n",
    "                bottom_right = channel_map[y_low + 1][x_low + 1]\n",
    "\n",
    "                # Distances from the sample to the low (lx, ly) and high (hx, hy) grid lines.\n",
    "                lx = point_x - x_low\n",
    "                hx = x_low + 1 - point_x\n",
    "                ly = point_y - y_low\n",
    "                hy = y_low + 1 - point_y\n",
    "\n",
    "                # Each corner is weighted by the area of the rectangle opposite to it.\n",
    "                res += (top_left*(hx*hy) + top_right*(lx*hy)\n",
    "                        + bottom_left*(hx*ly) + bottom_right*(lx*ly))\n",
    "\n",
    "        print(res / count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.ops import roi_align\n",
    "\n",
    "# Library result for the same ROI, used to cross-check the manual loop:\n",
    "# same pooled output size, two sampling points per bin along each axis.\n",
    "result = roi_align(\n",
    "    per_level_feature,\n",
    "    rois_in_feature,\n",
    "    (pooled_height, pooled_width),\n",
    "    sampling_ratio=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.1315,  0.3561,  0.1099,  0.0105,  0.1069,  0.4987, -0.0353],\n",
       "        [ 0.5771,  0.1737,  0.0971, -0.0943, -0.0527,  0.6988,  0.3200],\n",
       "        [-0.1067, -0.4572, -0.3240, -0.5574, -0.3299,  0.5116,  0.3760],\n",
       "        [-0.5320, -0.6477, -0.8813, -0.8584, -0.5922,  0.2852,  0.3046],\n",
       "        [-0.3979, -0.1485, -0.5802, -0.2516, -0.3406,  0.3955,  0.3633],\n",
       "        [ 0.1361, -0.0142,  0.1989,  0.8769,  0.2482,  0.3156,  0.4014],\n",
       "        [ 0.6164,  0.3802,  0.7231,  0.8693,  0.3949,  0.4931,  0.9423]],\n",
       "       grad_fn=<SelectBackward>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show element [0][0] of the roi_align output (first ROI, first channel) to compare\n",
    "# with the values printed by the manual loop.\n",
    "result[0][0]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
